"""
This class extends the ProdFun class in the following
respects (DLEU19 = De Loecker, Eeckhout & Unger, 2019):
    - Specify which variables to include in the first stage
    - Specify which variables to use as inputs, and among those
        * which are static (e.g. fixed costs in DLEU19)
        * which are dynamic (e.g. capital in DLEU19)
    - Specify which inputs have a markup:
        * fixed (no markup, e.g. capital in DLEU19)
        * variable (markup, e.g. labor in DLEU19)
The structure of the class is left untouched.

@author: Giovanni Morzenti (giomorzent@live.it)
@author:  Burstein, Carvalho, Grassi

"""
############################
# Import modules & Packages
############################
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as pltcolors

#plt.rcParams['text.usetex'] = False; plt.rcParams['text.latex.unicode'] = False # this is needed on the BOX

from mpl_toolkits.mplot3d import Axes3D
import warnings
from scipy import stats 
from scipy import optimize
from time  import time


#import statsmodels.api as sm
############################
# Define the class
############################
class ProdFun:
    """
    This class estimates a Production Function
    using the procedure by Ackerberg, Caves & Frazer (2007,2015)
    which is a variation over Olley & Pakes (1996) and Levinsohn
    & Petrin (2003).
    
    The procedure is in 2 stages:
        1) Estimate non-parametrically output against a set of variables. These
           always include the production inputs, and possibly some other controls;
        2) Construct a GMM objective function w/n-moments, n being the number of 
           production inputs, and minimize it.
    
    Ultimately, the class estimates mark-ups (for variable inputs only)
    à la De Loecker, Eeckhout & Unger (2019), who build on De Loecker & Warzynski (2012).
    
    Input
    -----------------
    data: pandas.dataframe ['firmid', 'date', 'varcost', 'capital', 'sale', 'sector_2d', 'v', 'k', 'y']
        ['firmid']: identifier of firms (GVKEY in Compustat)
        ['date']: date of each observation, can be different from yearly
        ['varcost']: deflated variable input (COGS in Compustat) 
        ['capital']: deflated capital (PPENT in Compustat) 
        ['sale']: deflated sales (SALE in Compustat)
        ['sector_2d']: industry
        ['v']: log of 'varcost'
        ['k']: log of 'capital'
        ['y']: log of 'sale'
        ['x']: log of an optional additional variable (any other variable one wants to include)
            
    vars_how: dict, optional (default = {'v':['purge','static','variable'], 'k':['purge','dynamic','fixed']})
        Select inputs and provide options for their use in the estimation procedure
        
        {       'name'        : [   'purge'/'no purge'    ,      'no input'/'static'/'dynamic'      ,     'no input'/'variable'/'fixed']}
            name of variable         include in purge        input: not included,  static, dynamic            markup: no/yes/no
        
        ADVANCED: if you wish to select the order of the inputs you can use an ordered dictionary (see OrderedDict())
    
    M: int, optional (default = 3)
        order of interaction terms for inputs used in purging regression
        
    purging: bool, optional (default = True)
        True: purge sales by a regression with multiple inputs interactions
        False: no purging, and consequently no sales correction
        
    phi_true: string, optional (default = '')
        '': compute phi in the first stage using the standard procedure
        'var': use self.data['var'] as the true value of phi, instead of computing it with the first stage
        NOTE: it works only if purging = False
        
    FE: list(str), optional (default = [])
        Select which variable to use as FE in purging and initial values of GMM
        []: No fixed effects
        ['date']: year fixed effects
        ['firmid_FE']: firm fixed effects (noticeably slower)
        ['date','firmid_FE']: firm and date fixed effects (noticeably slower)
        
    FE_demean: boolean, optional (default = True)
        True: use demeaning for Fixed Effects
        False: use dummy variables for Fixed Effects
        
    init_FE: boolean, optional (default = True)
        True: add selected FE to the initial values OLS equation

    translog: bool, optional (default = True)
        True: adding interactions terms with same power (up to 2, e.g. v^2) 
              allowing for non-linearity in the production function
    
    tl_inter: bool, optional (default = False)
        True: add interaction terms of power 1 between different inputs (e.g. v*k)
        
    GMM_cons: bool, optional (default = True)
        True: include the constant in the GMM estimation
        
    beta_init: np array, optional (default = [])
        []: compute initial values with OLS
        Otherwise this allows one to specify the initial values of the GMM estimation. For example:
        np.array([1, 0.4, 0.6])
        
    delta_simplex: float, optional (default = 0)
        for the 'NM' optimization method
        defines the coefficient that multiplies the base to form each vertex of the simplex
        0: use the default simplex in the NM algorithm
        d: consider p as initial value, the simplex will be [p, p+[d,0,..,0], p+[0,d,0,..,0], ..., p+[0,...,0,d]]
        
        ADVANCED: one can specify the vertices of the simplex one by one, in all its dimensions
        [d0,d1,...,dn]: the simplex will be [p, p+[d0,0,..,0], p+[0,d1,0,..,0], ..., p+[0,...,0,dn]]
        
    beta_true: np array, optional (default = [])
        [] no value specified, no value will be shown in the heatmap
        specify the true value of betas, e.g. np.array([0.3, 0.7])
        
    init_control: list of strings, optional (default = [])
        include additional controls in the OLS regression to compute initial values    
        []: include no additional regressors
        ['p','ms']: include variables 'p' and 'ms' as additional regressors
        
    optim: str, optional (default = 'NM')
        Select which optimizer to use
        'NM': for the Nelder-Mead simplex algorithm
        'BFGS': for quasi-Newton method of Broyden, Fletcher, Goldfarb, and Shanno
        'FSOLVE': for MINPACK’s hybrd and hybrj algorithms to find roots of non linear system of equations 
        'BASINHOPPING': for the basin-hopping algorithm (global minimum)
        
        ADVANCED: Try several optimizers if 'NM' does not converge
        'Iter': Executes 'NM',
            if the optimization is not successful executes 'FSOLVE',
            if the optimization is not successful executes 'BASINHOPPING'
    
    AR_c: bool, optional (default = True)
        True: Includes the constant in the AR used to compute moments for GMM
        
    AR_sqlag: bool, optional (default = False)
        True: Includes squared lag in the AR used to compute moments for GMM
        
    AR_price: bool, optional (default = False)
        True: Include price and lag of price as control in the AR used to compute moments for GMM
        
    omega_price: bool, optional (default = False)
        True: correct omega in moments() method using price
        
    demean: bool, optional (default = False)
        True: demeans variables before first stage, initial values and GMM
        
    NormMoments: bool, optional (default = False)
        True: Renormalize moments so as to have correlation rather than covariance

    LogMoments: bool, optional (default = False)
        True: Take log of moments before giving them to the solver
        NOTE: This is not implemented for FSOLVE
        
    RTS_const: tuple, optional (default = (False,''))
        True: Impose constant return to scale
        '': no option, allowed only with False
        'explicit': explicitly change the functional form, by reducing the number of parameters
        'constraints': impose constraints on parameters
        
    subsample: string, optional (default = '')
        '': use the entire sample for the estimation
        'xxx': use the variable 'xxx' in the panda dataframe to identify a subsample on which
               to conduct the estimation. Then estimated elasticities are applied to the whole 
               dataset in order to compute markups for every observation.
               NOTE: variable 'xxx' must be boolean and it has to be True for subsample obs
            
    plot_moments_fn: bool, optional (default = False)
        True: plot the moment function in 3D
        NOTE: it works only with two inputs (both CD and TL), since the plot is in 3D
              for the TL case, it plots the first order coefficients
    
    markups: bool, optional (default = True)
        True: computes firm specific markups and stores them in self.dta_mu
        False: does not compute markups (recommended for low memory usage)
    
    sales_correction: bool, optional (default = True)
        True: applies sales correction in the computation of markups
    
    verbose: bool, optional (default = True)
        True: Displays several metrics and results of the fitting
        False: Displays nothing, and stores all results in the object        
    
    se : bool, (default = False)
        True: compute bootstrap SEs for betas (elasticities)
    
    alpha_se : float, (default = .05)
        confidence level for betas CI. Must be in [0, 1].
        Ignored if `se=False`
    
    nb : int, (default = 100)
        number of bootstrap repetitions for SEs.
        Ignored if `se=False`     

    plot_moments_fn_param: Dict (default = {'ngrid': 100, 'NbNorm': 20, 'xlim': [], 'ylim': [], 'mesh': False})
        ngrid: int number of grid point by dimension
        NbNorm: int number of color class
        xlim: array 2*1 boundaries along the x-axis
        ylim: array 2*1 boundaries along the y-axis
        mesh: bool do you want to use a pcolormesh or contourf
                                
    
    Output
    -----------------
    self.dta: pandas.dataframe
        Dataframe used in the estimation, containing all interaction and lagged variables
        in this dataframe the first year is dropped for each firm due to lagged variables
                
    self.dta_mu: pandas.dataframe
        ['date','sale','varcost','sector_2d','epsilon','mu']
        Dataframe containing markups ['mu'], computed per each firm-date observation
        Markups are computed also for the years dropped in self.dta
    
    self.b_ols: np.array
        Initial values for elasticities estimated with GMM
    
    self.betas: np.array
        Elasticities estimated with GMM
        
    self.conv: tuple
        Full output of the optimizer (contains useful information on convergence)
        
    self.plt_3D: figure
        3D surface of the moment function (you can save the figure using self.plt_3D.savefig('fig.png'))
    
    self.plt_heatmap: figure
        Heatmap of the moment function (you can save the figure using self.plt_heatmap.savefig('fig.png'))
    
    self.ses : numpy 1D array, bootstrap SEs for betas (elasticities)
    
    self.CI : numpy 2D array,  alpha-level Normal bootstrap 
         CI for betas.  1st column is lower endpoint CI 
         and 2nd column is upper endpoint CI.

    """
    def __init__(self,
                 data,
                 vars_how = {'v':['purge','static','variable'], 'k':['purge','dynamic','fixed']},
                 M=3,
                 purging=True,
                 phi_true='',
                 FE=[],
                 FE_demean = True,
                 init_FE = True,
                 translog = True,
                 tl_inter = False,
                 GMM_cons = True,
                 beta_init = [],
                 delta_simplex = 0,
                 beta_true = [],
                 init_control = [],
                 optim='NM',
                 AR_c=True,
                 AR_sqlag=False,
                 AR_price = False,
                 omega_price = False,
                 demean=False,
                 NormMoments=False,
                 LogMoments=False,
                 RTS_const=(False,''),
                 RTS=1,
                 subsample='',
                 plot_moments_fn = False,
                 plot_moments_fn_coeff = [0,1],
                 markups=True,
                 sales_correction=True,
                 verbose=True,
                 se=False,
                 alpha_se=.05,
                 nb=100,
                 plot_moments_fn_param={'ngrid': 100, 'NbNorm': 20, 'xlim': [], 'ylim': [], 'mesh': False}):
        
        # check some errors in vars_how        
        for key in vars_how:
            if len(vars_how[key]) != 3:
                raise KeyError('Please specify all necessary information for {} (see help).'.format(key))
            
            for element in vars_how[key]:
                if element not in ['no input', 'purge', 'no purge', 'variable', 'fixed', 'static', 'dynamic']:
                    raise KeyError('Keyword in {} is not valid (see help)'.format(key))
                    
            if vars_how[key][0] == 'no purge':
                
                if vars_how[key][1] != 'no input':
                    raise ValueError('{} is not included in the first stage so it cannot be an input.'.format(key))
                    
                elif vars_how[key][2] != 'no input':
                    raise ValueError('{} is not included in the first stage so it can be neither a variable nor a fixed input.'.format(key))
                    
            elif vars_how[key][0] == 'purge' and vars_how[key][1] != 'no input':
                
                if vars_how[key][2] == 'no input':
                    raise ValueError('{} is an input, specify whether it is fixed or variable.'.format(key))

        """
        Functions are executed here.
        Workflow below.
        """       

        # store parameters
        self.data = data
        self.vars_how = vars_how
        self.M = M
        self.purging = purging
        self.phi_true = phi_true
        self.FE = FE
        self.FE_demean = FE_demean
        self.init_FE = init_FE
        self.translog = translog
        self.tl_inter = tl_inter
        self.GMM_cons = GMM_cons
        self.beta_init = beta_init
        self.delta_simplex = delta_simplex
        self.beta_true = beta_true
        self.init_control = init_control
        self.optim = optim
        self.AR_c = AR_c
        self.AR_sqlag = AR_sqlag
        self.AR_price = AR_price
        self.omega_price = omega_price
        self.demean = demean
        self.NormMoments = NormMoments
        self.LogMoments = LogMoments
        self.RTS_const = RTS_const
        self.RTS = RTS
        self.subsample = subsample
        self.plot_moments_fn = plot_moments_fn
        self.plot_moments_fn_coeff = plot_moments_fn_coeff
        self.markups = markups
        self.sales_correction = sales_correction
        self.verbose = verbose
        self.se = se
        self.alpha_se = alpha_se
        self.nb = nb
        self.plot_moments_fn_param=plot_moments_fn_param


        # load the data
        self.dta = data
        # Initial cleaning and purging of sales
        self.purge()
        # initial values
        self.initial_values()
        # initialize GMM
        self.initialize_GMM()
        
        # fit GMM
        self.fit_GMM()       
        
        # plot moment function
        if self.plot_moments_fn:
            self.plot_moments()
        
        if self.markups:
            # compute markups
            self.compute_markups()        
            # clean the dataset for output
            self.clean_dta_mu()
            
        # compute SEs and CI for betas
        if self.se:
            self.boot_se()

        
    #----------------------------------------------------------#  
    def purge(self):
        """
        This is the first stage in ACF(2007), namely to regress
        firm level output against a set of variables included
        in the purge.
        Regression here is non-parametric using a M-order
        polynomial. 
        """

        with warnings.catch_warnings(): # suppress wornings of log(0)
            warnings.simplefilter("ignore")
            self.dta['firmid_FE'] = self.dta['firmid']
            self.dta['date_FE'] = self.dta['date']
        
        #set up the panel data structure
        self.dta = self.dta.set_index(['firmid','date'], drop=False)
        #self.dta = self.dta.set_index(['firmid','date'], drop=True)
        self.dta = self.dta.drop(columns=['firmid'])
        
        # drop missing obs
#        self.dta = self.dta.replace([np.inf, -np.inf], np.nan).dropna()
        
        # store dataset for computation of firm level markups        
        if self.markups:
            self.dta_mu = self.dta
            
        # restrict estimation to a subsample of the entire dataset
        if self.subsample!='':
            self.dta = self.dta[self.dta[self.subsample]==True]
        
        # create the constant for OLS estimation         
        self.dta['const'] = 1 # add constant
    
        if self.purging == True:            
            # select which variables to include in purging stage
            vars_purging = []
            vars_purging2 = []
            for key in self.vars_how:
#                if vars_how[key][0] == 'purge':
                if self.vars_how[key][0] == 'purge' and self.vars_how[key][1] != 'no input':
                    vars_purging.append(key)
                    vars_purging2.append(key)
                    
            # create interactions
            self.interlist = []
            
            if self.GMM_cons:
                self.interlist.append('const')
            
            # additional controls
            for key in self.vars_how:
                if self.vars_how[key][0] == 'purge' and self.vars_how[key][1] == 'no input':
                    self.interlist.append(key) # just the control

#                    for i in range(2,self.M+1): # higher power for additional controls
#                        self.dta[key + str(i)] = self.dta[key].pow(i)
#                        self.interlist.append(key + str(i))
                        
            # interaction terms
#            for var in vars_purging:
#                for i in range(1,self.M+1):
#                    self.dta[var + str(i)] = self.dta[var].pow(i)
#                    self.interlist.append(var + str(i))
#                    for var2 in vars_purging2:
#                        if var2 != var:
#                            for j in range(1,self.M+1):
##                                self.dta[var + str(i) + var2 + str(j)] = self.dta[var].pow(i)*self.dta[var2].pow(j)
##                                self.interlist.append(var + str(i) + var2 + str(j))
#                                if i+j<=self.M:
#                                    self.dta[var + str(i) + var2 + str(j)] = self.dta[var].pow(i)*self.dta[var2].pow(j)
#                                    self.interlist.append(var + str(i) + var2 + str(j))
#                                    
#                vars_purging2.remove(var)
             
            # interaction terms
            vars_purging2=vars_purging.copy()
            for var in vars_purging:
                for i in range(1,self.M+1):
                    self.dta[var + str(i)] = self.dta[var].pow(i)
                    self.interlist.append(var + str(i))
                    vars_purging3=vars_purging2.copy()
                    for var2 in vars_purging2:
                         if var2 != var:
                             for j in range(1,self.M+1):
                                 if i+j<=self.M:
                                     self.dta[var + str(i) + var2 + str(j)] = self.dta[var].pow(i)*self.dta[var2].pow(j)
                                     self.interlist.append(var + str(i) + var2 + str(j))
                                 for var3 in vars_purging3:
                                     if var3 != var and var3 != var2:
                                         for k in range(1,self.M+1):
                                             if i+j+k <= self.M:
                                                 self.dta[var + str(i) + var2 + str(j)+ var3 + str(k)] = self.dta[var].pow(i)*self.dta[var2].pow(j)*self.dta[var3].pow(k)
                                                 self.interlist.append(var + str(i) + var2 + str(j)+ var3 + str(k))
                         vars_purging3.remove(var2)
                vars_purging2.remove(var)
                    
            
            
            # Prepare matrices for OLS regression
            Y = self.dta[['y']]

            
            X = self.dta[self.interlist+self.FE]
            
            
            
            if self.FE_demean: # use demeaning for FE
                X = X.drop(columns=self.FE)
                X_pred = np.array(X) # store it to compute phi and epsilon
                Y_pred = np.array(Y) # store it to compute phi and epsilon
                if self.FE != []: # add desired fixed effects
#                    X = X - X.mean(level=self.FE) # demean the variable
#                    Y = Y - Y.mean(level=self.FE) # demean the variable
                    if len(self.FE)<3:
                        for var_FE in self.FE:
                            X = X - X.replace([np.inf, -np.inf], np.nan).dropna().mean(level=var_FE) # demean the variable
                            Y = Y - Y.replace([np.inf, -np.inf], np.nan).dropna().mean(level=var_FE) # demean the variable
                    else: print('Impossible to demean with more than 2 FE levels')
            
            else: # use dummy variables for FE
                if self.FE != []: # add desired fixed effects
                    if self.GMM_cons:
                        dummies = pd.get_dummies( sum(X[self.FE].values.tolist(),[]) , drop_first=True )
                    else:
                        dummies = pd.get_dummies( sum(X[self.FE].values.tolist(),[]) , drop_first=False )
                    X = pd.concat([X.reset_index(drop=True),dummies.reset_index(drop=True)], axis=1)
                    Y = Y.reset_index(drop=True)
                    X = X.drop(columns=self.FE)
                X_pred = np.array(X) # store it to compute phi and epsilon
                Y_pred = np.array(Y) # store it to compute phi and epsilon

            C=pd.concat([X,Y],axis=1).replace([np.inf, -np.inf], np.nan).dropna() # drop missing values
            X = C.drop(columns=['y']); Y = C[['y']] # get back X and Y
                
            X = np.array(X)
            Y = np.array(Y) # turn Y into an array as well
            
            if self.demean:
                Y = Y - np.nanmean(Y)
#                Y_pred = Y_pred - np.nanmean(Y_pred)
                if self.FE != [] and self.FE_demean==False:
                    if self.GMM_cons:
                        X[:,1:len(self.interlist)] = X[:,1:len(self.interlist)] - np.nanmean(X[:,1:len(self.interlist)], axis=0)
                    else:
                        X[:,0:len(self.interlist)] = X[:,0:len(self.interlist)] - np.nanmean(X[:,0:len(self.interlist)], axis=0)
#                    X_pred[:,0:len(self.interlist)] = X_pred[:,0:len(self.interlist)] - np.nanmean(X_pred[:,0:len(self.interlist)], axis=0)
                elif self.FE_demean:
                    pass
                else:
                    X = X - np.nanmean(X, axis=0)
            
            
            #Perform OLS regression by hand
            #betas = np.linalg.inv(X.T @ X) @ (X.T @ Y) #old formula
            betas = np.linalg.pinv(X) @ Y
            predict = X_pred @ betas
            
            #Perform OLS regression using external package 
            #model = sm.OLS(Y,X)
            #results = model.fit()            
            #predict = np.array([results.predict()]).T
            
            
            #residual
            epsilon = Y_pred - predict

            self.Y_purge = Y_pred
            self.Y_predict = predict
            

            
            #self.Rsq_purge = 1 - ( np.var(epsilon[np.logical_not(np.isnan(epsilon))]) / np.var(Y[np.logical_not(np.isnan(Y))]) )
            self.Rsq_purge = 1 - ( np.var(epsilon[np.logical_not(np.isnan(epsilon))]) / np.var(Y_pred[np.logical_not(np.isnan(Y_pred))]) )
            
            
            if self.verbose:
                print('First Stage R squared = ' + str(self.Rsq_purge))
            
            self.dta['phi'] = predict
            self.dta['epsilon'] = epsilon
        
        else:
            self.dta['phi'] = self.dta['y'] 
            self.dta['epsilon'] = 0
            
            if self.phi_true!='':
                print('No first stage, phi substituted with ' + self.phi_true)
                self.dta['phi'] = self.dta[self.phi_true] 
                self.dta['epsilon'] = self.dta['y'] - self.dta[self.phi_true]
                
    
   #----------------------------------------------------------# 
    def initial_values(self):
        """
        Construct the initial conditions to give to the GMM optimizer.
        """

        # drop missing obs
#        self.dta = self.dta.replace([np.inf, -np.inf], np.nan).dropna()
        
        #sort data by firmsId and year
        self.dta=self.dta.sort_values(by=['firmid_FE','date_FE'])

        # constuct lags of purged sales
        self.dta['phi_lag'] = self.dta.groupby(['firmid'])['phi'].shift(1)
        
        # lag of controls for GMM
        if self.AR_price or self.omega_price:
            self.dta['p_lag'] = self.dta.groupby(['firmid'])['p'].shift(1)
        
        
        # constuct lags of inputs
        for key in self.vars_how:
            if self.vars_how[key][1] != 'no input':
                self.dta[key + '_lag'] = self.dta.groupby(['firmid'])[key].shift(1)
        
        if self.translog == True:
            
            var_list_support = [] # list needed for interaction terms
            for key in self.vars_how:
                if self.vars_how[key][1] != 'no input':
                    var_list_support.append(key)
                    
            for key in self.vars_how:
                if self.vars_how[key][1] != 'no input':
                    
                    # construct square terms
                    self.dta[key + '2'] = self.dta[key].pow(2)
                    self.dta[key + '_lag2'] = self.dta[key + '_lag'].pow(2)
                    
                    if self.tl_inter: # construct interaction terms
                        for key2 in var_list_support: 
                            if key2 != key:
                                self.dta[key + key2] = self.dta[key] * self.dta[key2]
                                self.dta[key + '_lag' + key2] = self.dta[key+'_lag'] * self.dta[key2]
                                self.dta[key + key2 + '_lag'] = self.dta[key] * self.dta[key2+'_lag']
                                self.dta[key + '_lag' + key2 + '_lag'] = self.dta[key+'_lag'] * self.dta[key2+'_lag']
                        
                        var_list_support.remove(key) # drop variable from list, so as not to duplicate terms
                            
        # drop missing obs (one obs per firm)
        self.dta = self.dta.replace([np.inf, -np.inf], np.nan).dropna()


        ## vector of variables to be put in the initial value OLS
        
        self.variables = []
        
        if self.GMM_cons:
            self.variables.append('const')
        
        for key in self.vars_how:
            if self.vars_how[key][1] != 'no input':
                self.variables.append(key)
                
        if self.translog == True: # include the translog function specification
            
            var_list_support = [] # list needed for interaction terms
            for key in self.vars_how:
                if self.vars_how[key][1] != 'no input':
                    var_list_support.append(key)
            
            for key in self.vars_how:
                
                if self.vars_how[key][1] != 'no input':
                    self.variables.append(key+'2')
                    
                    if self.tl_inter: # construct interaction terms
                        for key2 in var_list_support: 
                            if key2 != key:
                                self.variables.append(key+key2)
                                
                        var_list_support.remove(key) # drop variable from list, so as not to duplicate terms


        
#        if self.beta_init==[]:
        Y = self.dta[['y']]
    
        X = self.dta[self.variables+self.init_control+self.FE]
#        X = self.dta[self.variables+self.init_control]
        
        nb_var = X.shape[1] -len(self.init_control) - len(self.FE) # getting the number of variables
#        nb_var = X.shape[1] -len(self.init_control) # getting the number of variables
        

        if self.init_FE:
            if self.FE_demean:
                X = X.drop(columns=self.FE)
                if self.FE != []: # add desired fixed effects
#                    X = X - X.mean(level=self.FE) # demean the variable
#                    Y = Y - Y.mean(level=self.FE) # demean the variable
                    if len(self.FE)<3:
                        for var_FE in self.FE:
                            X = X - X.mean(level=var_FE) # demean the variable
                            Y = Y - Y.mean(level=var_FE) # demean the variable
                    else: print('Impossible to demean with more than 2 FE levels')
            else:
                if self.FE != []: # add desired fixed effects
                    if self.GMM_cons:
                        dummies = pd.get_dummies( sum(X[self.FE].values.tolist(),[]) , drop_first=True )
                    else:
                        dummies = pd.get_dummies( sum(X[self.FE].values.tolist(),[]) , drop_first=False )
                    X = pd.concat([X.reset_index(drop=True),dummies.reset_index(drop=True)], axis=1)
                    Y = Y.reset_index(drop=True)
                    X = X.drop(columns=self.FE)


        C=pd.concat([X,Y],axis=1).dropna() # drop missing values
        X = C.drop(columns=['y']); Y = C[['y']] # get back X and Y

        X = np.array(X)
        Y = np.array(Y)

        if self.demean:
            Y = Y - np.nanmean(Y)
            if self.init_FE and self.FE_demean==False:
                if self.GMM_cons:
                    X[:,1:len(self.interlist)] = X[:,1:len(self.interlist)] - np.nanmean(X[:,1:len(self.interlist)], axis=0)
                else:
                    X[:,0:len(self.interlist)] = X[:,0:len(self.interlist)] - np.nanmean(X[:,0:len(self.interlist)], axis=0)
#                X[:,0:len(self.variables+self.init_control)] = X[:,0:len(self.variables+self.init_control)] - np.nanmean(X[:,0:len(self.variables+self.init_control)], axis=0)
            elif self.FE_demean:
                pass
            else:
                X = X - np.nanmean(X, axis=0)


        #b_ols = np.linalg.inv(X.T @ X) @ (X.T @ Y)
        b_ols = np.linalg.pinv(X)  @ Y 
        
        if self.RTS_const[0] and self.RTS_const[1]=='explicit': # apply Traina constant return to scale correction
            if self.translog:
                if self.GMM_cons:
                    self.b_ols = np.squeeze(b_ols[[0,1,3],])
                else:
                    self.b_ols = np.squeeze(b_ols[[0,2],])
            else:
                if self.GMM_cons:
                    self.b_ols = np.squeeze(b_ols[[0,1],])
                else:
                    self.b_ols = np.squeeze(b_ols[[0],])
        else:
            self.b_ols = np.squeeze(b_ols[:nb_var,])
            
        # substitute the initial values specified manually in beta_init
        self.b_ols = np.array( list(self.beta_init)  + list(self.b_ols[len(self.beta_init):]) ) 


        if self.verbose and len(self.beta_init)==0:
            print('Initial values computed with OLS:')
            print(self.b_ols)
        if self.verbose and len(self.beta_init)>0 and len(self.beta_init)<len(self.variables):
            print('Initial values:')
            print(self.b_ols)
            print('Out of which the following were imputed:')
            print(self.beta_init)
        if self.verbose and len(self.beta_init)==len(self.variables):
            print('Imputed Initial values:')
            print(self.beta_init)


    #----------------------------------------------------------#
    def initialize_GMM(self):
        """
        Build the instrument and regressor matrices consumed by ``moments``.

        The second element of each ``self.vars_how`` entry sets the timing of
        that variable:
            'static'   -> instrumented by its lag (column ``<name>_lag``);
            'dynamic'  -> instrumented by its current value (column ``<name>``);
            'no input' -> not part of the production function.
        When ``self.RTS_const[0]`` is truthy, static inputs get no instrument
        (their moment is replaced by the returns-to-scale restriction), and
        translog interaction instruments are not built.

        Attributes set
        --------------
        Z : instrument matrix, columns of ``self.dta`` ('const' first if GMM_cons).
        variables : list of regressor column names: 'const' (if GMM_cons), each
            input, then (if ``self.translog``) for each input its square
            ``<name>2`` followed by its interactions ``<name><other>`` with
            the inputs after it (only when ``self.tl_inter``).
        X, X_lag : regressor matrix and the matching lagged columns.
        Y : the 'y' column as an (n, 1) array.
        C : the 'const' column as an (n, 1) array.
        phi, phi_lag : the 'phi' / 'phi_lag' columns as (n, 1) arrays.
        AR_control, AR_control_lag : the 'p' / 'p_lag' columns, only when
            ``self.AR_price`` or ``self.omega_price``.

        With ``self.demean``, Z, X and X_lag (every column except the leading
        constant when GMM_cons) and phi, phi_lag are centred on their nan-mean;
        Y and C are not demeaned here.

        NOTE: the column order of Z and X matters -- ``moments`` and ``fit_GMM``
        address coefficients by position when imposing returns to scale.
        """

        ## Construct matrices for GMM

        # construct the matrix of instruments
        # Keep an eye on the right moment condition
        instruments = []

        if self.GMM_cons:
            instruments.append('const')

        # first-order instruments: the lag for static inputs, the level for dynamic ones
        for key in self.vars_how:
            if self.vars_how[key][1] == 'static':
                if self.RTS_const[0]:
                    pass # static moment replaced by the RTS restriction
                else:
                    instruments.append(key + '_lag')
            elif self.vars_how[key][1] == 'dynamic':
                instruments.append(key)
            else:
                pass # 'no input' variables are not instrumented

        if self.translog == True:

            var_list_support = [] # list needed for interaction terms
            for key in self.vars_how:
                if self.vars_how[key][1] != 'no input':
                    var_list_support.append(key)

            for key in self.vars_how:

                # squared terms, with the same timing as the first-order instrument
                if self.vars_how[key][1] == 'static':
                    if self.RTS_const[0]:
                        pass
                    else:
                        instruments.append(key + '_lag2')
                elif self.vars_how[key][1] == 'dynamic':
                    instruments.append(key + '2')
                else:
                    pass

                if self.RTS_const[0]==False and self.tl_inter and self.vars_how[key][1] != 'no input': # construct interaction terms
                    # pair this input with every input still in var_list_support;
                    # column name = concatenation of the two instrument names
                    # (presumably precomputed in self.dta -- verify upstream)
                    for key2 in var_list_support: 
                        if key2 != key:

                            if self.vars_how[key][1] == 'static':
                                first = key + '_lag'
                            elif self.vars_how[key][1] == 'dynamic':
                                first = key
                            else:
                                pass

                            if self.vars_how[key2][1] == 'static':
                                second = key2 + '_lag'
                            elif self.vars_how[key2][1] == 'dynamic':
                                second = key2
                            else:
                                pass

                            instruments.append(first + second)

                    var_list_support.remove(key) # drop variable from list, so as not to duplicate terms

        # Ackerberg et al. (ECMA 2015) also suggest phi_lag as an instrument (not used here)
        self.Z = np.array(self.dta[instruments])

        if self.demean:
            # never demean the constant column
            if self.GMM_cons:
                self.Z[:,1:] = self.Z[:,1:] - np.nanmean(self.Z[:,1:], axis=0)
            else:
                self.Z = self.Z - np.nanmean(self.Z, axis=0)


        ## vector of variables in the production function estimated with GMM

        self.variables = []
        lagged_variables = []

        if self.GMM_cons:
            self.variables.append('const')
            lagged_variables.append('const')


        for key in self.vars_how:
            if self.vars_how[key][1] != 'no input':
                self.variables.append(key)
                lagged_variables.append(key+'_lag')

        if self.translog == True: # include the translog function specification

            var_list_support = [] # list needed for interaction terms
            for key in self.vars_how:
                if self.vars_how[key][1] != 'no input':
                    var_list_support.append(key)

            for key in self.vars_how:

                if self.vars_how[key][1] != 'no input':
                    self.variables.append(key+'2')
                    lagged_variables.append(key+'_lag2')

                    if self.tl_inter: # construct interaction terms
                        for key2 in var_list_support: 
                            if key2 != key:
                                self.variables.append(key+key2)
                                lagged_variables.append(key+'_lag'+key2+'_lag')

                        var_list_support.remove(key) # drop variable from list, so as not to duplicate terms

        self.X = np.array(self.dta[self.variables])
        self.X_lag = np.array(self.dta[lagged_variables])


        # dependent variable for initial OLS values
        self.Y = np.array([np.array(self.dta['y'])]).T

        # constant
        self.C = np.array([np.array(self.dta['const'])]).T
        # dependent variable of the GMM
        self.phi = np.array([np.array(self.dta['phi'])]).T
        self.phi_lag = np.array([np.array(self.dta['phi_lag'])]).T

        # additional control in the GMM (price terms in the AR / omega equation)
        if self.AR_price or self.omega_price:
            self.AR_control = np.array([np.array(self.dta['p'])]).T
            self.AR_control_lag = np.array([np.array(self.dta['p_lag'])]).T

        if self.demean:
            # the constant column (first, when GMM_cons) is left untouched
            if self.GMM_cons:
                self.X[:,1:] = self.X[:,1:] - np.nanmean(self.X[:,1:], axis=0)
                self.X_lag[:,1:] = self.X_lag[:,1:] - np.nanmean(self.X_lag[:,1:], axis=0)
            else:
                self.X = self.X - np.nanmean(self.X, axis=0)
                self.X_lag = self.X_lag - np.nanmean(self.X_lag, axis=0)
            self.phi = self.phi - np.nanmean(self.phi, axis=0)
            self.phi_lag = self.phi_lag - np.nanmean(self.phi_lag, axis=0)

        
    
    #----------------------------------------------------------# 
    def moments(self, betas_1D):
        """
        Evaluate the GMM criterion at the candidate coefficients ``betas_1D``.

        Steps
        -----
        1. With ``RTS_const == (True, 'explicit')`` the solver only sees the free
           coefficients; the full vector is rebuilt here by imposing
           ``sum of input elasticities == self.RTS`` (and, with translog, the
           second-order restrictions).
        2. Productivity ``omega = phi - X @ betas`` (and its lag), net of the
           price control when ``self.omega_price``.
        3. OLS of omega on a polynomial in its lag (constant if ``AR_c``,
           squared lag if ``AR_sqlag``, price controls if ``AR_price``); the
           coefficients are stored in ``self.g_b``.
        4. Moments ``Z' xi`` of the innovation xi, normalised to correlations
           when ``self.NormMoments``.

        Returns
        -------
        If ``self.optim == 'FSOLVE'``: 1-D array with one residual per moment,
        followed by the squared RTS / translog restrictions when
        ``RTS_const == (True, 'constraints')``.
        If ``self.optim`` is 'SLSQP' or 'trust-constr': the sum of squared
        moments (its log if ``LogMoments``); restrictions are left to the solver.
        Otherwise: the sum of squared moments plus the squared restrictions as a
        penalty when constraints are requested, else that sum (its log if
        ``LogMoments``).
        """
        betas = np.array([betas_1D]).T  # column vector

        if self.RTS_const[0] and self.RTS_const[1]=='explicit': # apply Traina constant return to scale correction
            if self.GMM_cons:
                if self.translog:
                    betas = np.array([[betas[0,0]],[betas[1,0]],[self.RTS-betas[1,0]],[betas[2,0]],[-2*betas[2,0]],[betas[2,0]]])
                else:
                    betas = np.array([[betas[0,0]],[betas[1,0]],[self.RTS-betas[1,0]]])
            else:
                if self.translog:
                    betas = np.array([[betas[0,0]],[self.RTS-betas[0,0]],[betas[1,0]],[-2*betas[1,0]],[betas[1,0]]])
                else:
                    betas = np.array([[betas[0,0]],[self.RTS-betas[0,0]]])

        # productivity implied by the candidate coefficients
        omega = self.phi - self.X @ betas
        omega_lag = self.phi_lag - self.X_lag @ betas

        if self.omega_price:
            omega = omega - self.AR_control
            omega_lag = omega_lag - self.AR_control_lag

        # regressors of the law of motion of omega
        omega_lag_pol = omega_lag
        if self.AR_c: # include the constant
            omega_lag_pol = np.concatenate((self.C, omega_lag_pol), axis=1)
        if self.AR_sqlag: # include the squared lag
            omega_lag_pol = np.concatenate((omega_lag_pol, np.square(omega_lag)), axis=1)
        if self.AR_price: # current and lagged price controls
            omega_lag_pol = np.concatenate((omega_lag_pol, self.AR_control, self.AR_control_lag), axis=1)

        # OLS normal equations; solve() avoids forming the explicit inverse
        g_b = np.linalg.solve(omega_lag_pol.T @ omega_lag_pol, omega_lag_pol.T @ omega)
        self.g_b = g_b  # save the AR process parameters

        xi = omega - omega_lag_pol @ g_b  # innovation in productivity

        if self.NormMoments: # normalize moments so as to have correlation, rather than covariance
            moments = (self.Z.T @ xi) / ((np.array([np.diagonal(self.Z.T @ self.Z)]).T)**0.5 * (xi.T @ xi)**0.5)
        else: # leave moments as covariance
            moments = self.Z.T @ xi

        # squared residuals of the Traina returns-to-scale restrictions
        # (sum of the first-order elasticities == RTS; translog second-order terms)
        rts_constraints = self.RTS_const[0] and self.RTS_const[1]=='constraints'
        restrictions = []
        if rts_constraints:
            i0 = 1 if self.GMM_cons else 0  # skip the constant coefficient
            restrictions.append((self.RTS - betas[i0,0] - betas[i0+1,0])**2)
            if self.translog:
                restrictions.append((2*betas[i0+2,0] + betas[i0+3,0])**2)
                restrictions.append((2*betas[i0+4,0] + betas[i0+3,0])**2)

        if self.optim=='FSOLVE': # FSOLVE requires one value to bring to 0 per each equation
            if rts_constraints:
                return np.append(np.squeeze(moments), np.array(restrictions))
            return np.squeeze(moments)

        objective = np.squeeze(moments.T @ moments)

        if self.optim in ('SLSQP', 'trust-constr'): # these solvers impose the restrictions themselves
            return np.log(objective) if self.LogMoments else objective

        # other solvers require just one positive value to minimize
        if rts_constraints:
            for penalty in restrictions:
                objective = objective + penalty
            return objective
        return np.log(objective) if self.LogMoments else objective
    
    
    #----------------------------------------------------------# 
    def fit_GMM(self):
        """
        Minimise ``self.moments`` starting from ``self.b_ols`` with the solver
        selected by ``self.optim``.

        Supported values of ``self.optim``
        ----------------------------------
        'NM'                      Nelder-Mead (``optimize.fmin``). A non-zero
                                  ``self.delta_simplex`` (scalar or one value per
                                  coefficient) sets the initial simplex.
        'BFGS'                    ``optimize.fmin_bfgs``.
        'FSOLVE'                  root finding on the moment vector (``optimize.fsolve``).
        'BASINHOPPING'            ``optimize.basinhopping``.
        'differential_evolution'  search in a box of half-width 0.2 around
                                  ``self.b_ols`` (or ``self.delta_simplex`` when
                                  non-zero). Work in progress.
        'SLSQP'                   ``optimize.minimize``; with
                                  ``RTS_const == (True, 'constraints')`` the
                                  returns-to-scale restrictions are equality constraints.
        'DualAnnealing'           global search on [0, 2] per coefficient. Work in progress.
        'Iter'                    NM, then FSOLVE, then BASINHOPPING, stopping at the
                                  first one that converges.

        Sets ``self.conv`` (raw solver output), ``self.optim_used`` (solver whose
        result is used) and ``self.betas`` (expanded to the full coefficient
        vector when ``RTS_const == (True, 'explicit')``). ``self.optim`` is left
        unchanged.

        Returns
        -------
        numpy.ndarray
            The estimated coefficients ``self.betas``.

        Raises
        ------
        ValueError
            If ``self.optim`` is not a supported solver.
        """
        t0 = time()

        self.optim_used = str(self.optim)

        def _message(res):
            # basinhopping / dual_annealing may store the message as a list of strings
            msg = res.message
            return msg[0] if isinstance(msg, (list, tuple)) else msg

        def _criterion(x):
            # scalar criterion at x as NM sees it; self.optim is always restored
            saved = self.optim
            self.optim = 'NM'
            try:
                return self.moments(x)
            finally:
                self.optim = saved

        def _delta_steps():
            # per-coefficient steps from self.delta_simplex (scalar or vector); None when unset/zero
            if self.delta_simplex is None:
                return None
            delta = np.atleast_1d(np.asarray(self.delta_simplex, dtype=float))
            if not np.any(delta != 0):
                return None
            return np.broadcast_to(delta, (len(self.b_ols),))

        def _run_fmin(initial_simplex):
            # Nelder-Mead; True when fmin reports no warning (warnflag == 0).
            # NB: with LogMoments the criterion keeps falling without reaching its
            # minimum, so fmin may report non-convergence even when it works.
            self.conv = optimize.fmin(self.moments, self.b_ols, initial_simplex=initial_simplex,
                                      full_output=True, disp=self.verbose,
                                      xtol=10**(-6), ftol=10**(-7), maxfun=10000)
            if self.verbose:
                print('Fun Value: ' + str(self.conv[1]))
            return self.conv[4] == 0

        def _run_fsolve():
            # fsolve needs moments() to return one residual per equation
            saved = self.optim
            self.optim = 'FSOLVE'
            try:
                self.conv = optimize.fsolve(self.moments, self.b_ols, full_output=True)
            finally:
                self.optim = saved
            if self.verbose:
                print(self.conv[-1])
                print('Fun Value: ' + str(_criterion(self.conv[0]))) # criterion value, for comparison
                print('        fvec: ' + str(self.conv[-3]['fvec']))
                print('        nfev: ' + str(self.conv[-3]['nfev']))
            return self.conv[-2] == 1  # ier == 1: a solution was found

        def _run_basinhopping():
            # global search; the estimate is taken directly from ret.x
            ret = optimize.basinhopping(self.moments, self.b_ols, disp=False)
            self.betas = ret.x
            self.ret = ret
            if self.verbose:
                print(_message(ret))
                print('Fun Value: ' + str(ret.fun))
            return ret

        if self.optim == 'NM':
            steps = _delta_steps()
            if steps is not None:
                # first vertex is the initial value; vertex i+1 moves coefficient i by steps[i]
                initial_simplex_our = np.vstack([self.b_ols, self.b_ols + np.diag(steps)])
                if self.verbose:
                    print('')
                    print('Imputed initial simplex:')
                    print(initial_simplex_our)
                    print('')
            else:
                initial_simplex_our = None
            _run_fmin(initial_simplex_our)

        elif self.optim == 'BFGS':
            # quasi-Newton BFGS: precise on small samples, memory hungry on large ones
            self.conv = optimize.fmin_bfgs(self.moments, self.b_ols, full_output=True,
                                           disp=self.verbose, gtol=10**(-6))
            if self.verbose:
                print('Fun Value: ' + str(self.conv[1]))

        elif self.optim == 'FSOLVE':
            # MINPACK hybrd/hybrj root finder
            _run_fsolve()

        elif self.optim == 'BASINHOPPING':
            _run_basinhopping()

        elif self.optim == 'differential_evolution': # still work in progress
            steps = _delta_steps()
            if steps is None:
                steps = np.full(len(self.b_ols), 0.2)  # default half-width, noted to work well for CD
            bounds = [(b - d, b + d) for b, d in zip(self.b_ols, steps)]
            self.conv = optimize.differential_evolution(self.moments, bounds=bounds)
            if self.verbose:
                print('Bounds:')
                print(np.round(bounds, 4))
                print('')
                print(self.conv)

        elif self.optim == 'SLSQP':
            constraints = ()
            if self.RTS_const[0] and self.RTS_const[1] == 'constraints':
                # linear equality constraints A @ x == rhs:
                #   x[i0] + x[i0+1] == RTS  (and, with translog, 2x[i0+2]+x[i0+3] == 0, 2x[i0+4]+x[i0+3] == 0)
                i0 = 1 if self.GMM_cons else 0  # skip the constant coefficient
                n_cons = 3 if self.translog else 1
                A = np.zeros((n_cons, len(self.b_ols)))
                rhs = np.zeros(n_cons)
                A[0, i0] = 1
                A[0, i0 + 1] = 1
                rhs[0] = self.RTS
                if self.translog:
                    A[1, i0 + 2] = 2
                    A[1, i0 + 3] = 1
                    A[2, i0 + 4] = 2
                    A[2, i0 + 3] = 1
                constraints = {'type': 'eq',
                               'fun': lambda x: A @ x - rhs,
                               'jac': lambda x: A}
            self.conv = optimize.minimize(self.moments, self.b_ols, method='SLSQP', constraints=constraints)
            if self.verbose:
                print( 'Message: ' + self.conv['message'] )
                print( 'Func val: ' + str(self.conv['fun']))

        elif self.optim == 'DualAnnealing': # still work in progress
            lw = [0] * len(self.b_ols)
            up = [2] * len(self.b_ols)
            self.conv = optimize.dual_annealing(self.moments, bounds=list(zip(lw, up)))
            if self.verbose:
                print('')
                print(self.conv)
                print(_message(self.conv))
                print('Fun Value: ' + str(self.conv.fun))

        elif self.optim == 'Iter':
            if self.verbose:
                print('')
                print('Attempting NM:')
            converged = _run_fmin(None)
            if converged:
                self.optim_used = 'NM'

            if not converged:
                if self.verbose:
                    print('')
                    print('Attempting FSOLVE:')
                converged = _run_fsolve()
                if converged:
                    self.optim_used = 'FSOLVE'

            if not converged:
                if self.verbose:
                    print('')
                    print('Attempting BASINHOPPING:')
                ret = _run_basinhopping()
                if not ret.lowest_optimization_result.success and self.verbose:
                    print('')
                    print('CONVERGENCE NOT ACHIEVED!')
                    print('')
                self.optim_used = 'BASINHOPPING'

            if self.verbose:
                print('') # do it to make the output look good

        else:
            raise ValueError("Unknown optimizer " + repr(self.optim) + "; expected one of 'NM', 'BFGS', "
                             "'FSOLVE', 'BASINHOPPING', 'differential_evolution', 'SLSQP', "
                             "'DualAnnealing', 'Iter'")

        t1 = time()

        # extract the coefficients from the solver output
        if self.optim_used == 'BASINHOPPING':
            pass  # self.betas already set from ret.x
        elif self.optim_used in ('SLSQP', 'differential_evolution', 'DualAnnealing'):
            self.betas = self.conv['x']  # OptimizeResult
        else:
            self.betas = self.conv[0]  # fmin / fmin_bfgs / fsolve tuples

        if self.RTS_const[0] and self.RTS_const[1]=='explicit': # rebuild the full vector under the Traina restriction
            if self.GMM_cons:
                if self.translog:
                    self.betas = np.array([self.betas[0],self.betas[1],self.RTS-self.betas[1],self.betas[2],-2*self.betas[2],self.betas[2]])
                else:
                    self.betas = np.array([self.betas[0],self.betas[1],self.RTS-self.betas[1]])
            else:
                if self.translog:
                    self.betas = np.array([self.betas[0],self.RTS-self.betas[0],self.betas[1],-2*self.betas[1],self.betas[1]])
                else:
                    self.betas = np.array([self.betas[0],self.RTS-self.betas[0]])

        if self.verbose:
            print('Fitting time: ' + str(np.round(t1-t0,2)) + ' s')
            print('Resulting Betas:')
            print(self.betas)

        return self.betas
    

    #----------------------------------------------------------# 
    def plot_moments(self):
        """
        Plot the GMM criterion around the estimate as a 3-D surface and as a heatmap.

        The coefficients at positions ``self.plot_moments_fn_coeff[0]`` (x axis)
        and ``self.plot_moments_fn_coeff[1]`` (y axis) are varied on an
        ``ngrid`` x ``ngrid`` grid. All other coefficients are held at
        ``self.betas``. Options come from ``self.plot_moments_fn_param``
        ('ngrid', 'NbNorm', 'mesh', 'xlim', 'ylim'). An empty xlim/ylim is
        replaced by an interval centred on the estimate that also covers the
        initial value ``self.b_ols``.

        The criterion is plotted in logs; it is already in logs when
        ``self.LogMoments``. When ``moments`` returns a vector (FSOLVE mode),
        its sum of squares is used. The heatmap marks the estimate with a red
        star, the initial value with a red cross and ``self.beta_true`` (if
        non-empty) with an arrow.

        Sets ``self.X_3D`` and ``self.Y_3D`` (``np.meshgrid(x, y)``) and
        ``self.Z_3D``, where ``Z_3D[j, i]`` is the criterion at ``(x[i], y[j])``
        (the same layout as the meshgrid). Also sets ``self.plt_3D`` and
        ``self.plt_heatmap``.
        """
        if self.verbose:
            print('Plotting the moment function...')

        c0 = self.plot_moments_fn_coeff[0]
        c1 = self.plot_moments_fn_coeff[1]
        xstar, ystar = self.betas[c0], self.betas[c1]
        x0, y0 = self.b_ols[c0], self.b_ols[c1]

        params = self.plot_moments_fn_param
        ngrid = params['ngrid'] # number of grid points on each dimension
        cmap = plt.get_cmap('ocean') # colormap
        NbNorm = params['NbNorm'] # number of colour classes
        mesh = params['mesh']
        xlim = params['xlim']
        ylim = params['ylim']

        if xlim is None or len(xlim) == 0:
            xint = max(abs(xstar - x0) * 1.5, abs(xstar) * 0.5)
            xlim = [xstar - xint, xstar + xint]
        if ylim is None or len(ylim) == 0:
            yint = max(abs(ystar - y0) * 1.5, abs(ystar) * 0.5)
            ylim = [ystar - yint, ystar + yint]

        # define the grid
        x = np.linspace(xlim[0], xlim[1], ngrid)
        y = np.linspace(ylim[0], ylim[1], ngrid)
        self.X_3D, self.Y_3D = np.meshgrid(x, y)
        self.Z_3D = np.zeros((ngrid, ngrid))

        # evaluate the criterion on the grid; works for any order of the two coefficients
        base = np.array(self.betas, dtype=float)
        for j, yv in enumerate(y):
            for i, xv in enumerate(x):
                betas_3D = base.copy()
                betas_3D[c0] = xv
                betas_3D[c1] = yv
                value = np.asarray(self.moments(betas_3D))
                if value.ndim > 0: # moment vector (FSOLVE mode): collapse to sum of squares
                    value = value @ value
                # row j <-> y[j], column i <-> x[i], as in np.meshgrid(x, y)
                self.Z_3D[j, i] = value

        if not self.LogMoments:
            self.Z_3D = np.log(self.Z_3D) # plot in logs all the time

        x_label = r'$\beta_{' + self.variables[c0] + '}$'
        y_label = r'$\beta_{' + self.variables[c1] + '}$'

        #### plot the figure in 3D
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(projection='3d')
        ax.plot_surface(self.X_3D, self.Y_3D, self.Z_3D, rstride=1, cstride=1, cmap=cmap, edgecolor='none')
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_zlabel('Moment Function (log)')
        ax.set_title('Moment Function in 3D')
        plt.show(block=False)
        plt.close(fig)
        self.plt_3D = fig # save the surface plot as an object in the class

        #### plot a heatmap
        fig, ax = plt.subplots(figsize=(10, 10))
        levels = np.linspace(self.Z_3D.min(), self.Z_3D.max(), NbNorm)
        norm = pltcolors.BoundaryNorm(boundaries=levels, ncolors=cmap.N, clip=True)
        if mesh:
            artist = ax.pcolormesh(self.X_3D, self.Y_3D, self.Z_3D, cmap=cmap, norm=norm)
        else:
            artist = ax.contourf(self.X_3D, self.Y_3D, self.Z_3D, cmap=cmap, norm=norm)
        ax.set_title('Heatmap of Moment Function')
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        # estimated betas (star) and initial values (cross)
        ax.plot(self.betas[c0], self.betas[c1], 'r*', ms=10)
        ax.plot(self.b_ols[c0], self.b_ols[c1], 'rX', ms=10)

        beta_true = self.beta_true
        if beta_true is not None and len(beta_true) > 0:
            beta_true = np.asarray(beta_true, dtype=float)
            ax.annotate("True",
                        xy=tuple(beta_true),
                        xytext=tuple(beta_true + np.array([-0.1, 0.1])),
                        arrowprops=dict(facecolor="black", width=0.5, headwidth=4, shrink=0.1))

        ax.axis([self.X_3D.min(), self.X_3D.max(), self.Y_3D.min(), self.Y_3D.max()])
        fig.colorbar(artist, ax=ax)
        plt.show(block=False)
        plt.close(fig)

        self.plt_heatmap = fig # save the heatmap as an object in the class
        
 
                
        
        

    #----------------------------------------------------------# 
    def compute_markups(self):
        """
        Calculate markups given estimated elasticities. Note that
        there is going to be one mark up per variable input.
        Relative to the baseline class we thus need to adjust for that.
        """
        
        ## construct the elasticity
        
        if self.sales_correction:
            sales_corr_term = np.exp(self.dta_mu['epsilon']);
        else:
            sales_corr_term = 1;
        
        for key in self.vars_how:
            if self.vars_how[key][2] == 'variable':
                
                elasticity = self.betas[self.variables.index(key)]
                #print(key)
                #print(elasticity.mean())
                
                
                if self.translog: # add squared terms in case of translog
                    elasticity = elasticity + 2*self.betas[self.variables.index(key+'2')]*self.dta_mu[key]
                    #print(key+'2')
                    #print(elasticity.describe())
                    #print(self.dta_mu[key].describe())
                    #print(' beta' + key+'2')
                    #print(self.betas[self.variables.index(key+'2')])
                    
                    if self.tl_inter: # add interaction term in case of translog
                        for key2 in self.vars_how:
                            if self.vars_how[key2][1] != 'no input' and key2 != key:
                                try:
                                    elasticity = elasticity + self.betas[self.variables.index(key+key2)]*self.dta_mu[key2]
                                    #print(key2)
                                    #print(elasticity.mean())

                                except:
                                    elasticity = elasticity + self.betas[self.variables.index(key2+key)]*self.dta_mu[key2]
                                    #print(key2)
                                    #print(elasticity.mean())

        
                #
                #self.dta_mu['mu_' + key] = elasticity * (np.exp(self.dta_mu['y']/sales_corr_term) / np.exp(self.dta_mu[key]))
                self.dta_mu['mu_' + key] = elasticity * (self.dta_mu['sale']/sales_corr_term) / np.exp(self.dta_mu[key])
        
                      
    #----------------------------------------------------------# 
    def clean_dta_mu(self):
        """
        Clean the output table to get nice results.
        """
#        with warnings.catch_warnings(): # suppress wornings of log(0)
#            warnings.simplefilter("ignore")
        self.dta_mu = self.dta_mu.drop(columns=['date'])
        self.dta_mu = self.dta_mu.reset_index()

        
        vars_keep = ['date','firmid','sale','varcost','sector_2d','epsilon']
        for key in self.vars_how:
            if self.vars_how[key][2] == 'variable':
                vars_keep.append('mu_' + key)
                
#        self.dta_mu = self.dta_mu[vars_keep]
        
        if self.verbose:
            for key in self.vars_how:
                if self.vars_how[key][2] == 'variable':
                    mu = 'mu_' + key
#                    mean_mu   = np.round(np.mean(self.dta_mu[mu],3)
#                    median_mu = np.round(np.median(self.dta_mu[mu],3)
                    #mean_mu   = np.round(np.mean(self.dta_mu[mu].dropna()),3)
                    median_mu = np.round( np.median(self.dta_mu[mu].dropna()) , 3)
                    iqr_mu = np.round( ( np.quantile(self.dta_mu[mu].dropna(),0.75) - np.quantile(self.dta_mu[mu].dropna(),0.25) ) / np.quantile(self.dta_mu[mu].dropna(),0.5)    ,3)
                    
                    #print('Variable {} has {} mean and {} median markup'.format(key, mean_mu, median_mu))
                    print('Variable {} has {} median and {} iqr markup'.format(key, median_mu, iqr_mu))

    
    def boot_block_resample(self):
        """ Block resampling with replacement 

            Sample with replacement the whole population
            of firms preserving autocorrelation
        """
        # horrible hack since class is not modular
        # need to delete 'date' since later readded
#        if 'date' in self.dta.columns:
#            dat = self.dta.drop(columns='date')
#        firmids = list(dat.index.get_level_values('firmid'))
#        
#        # select x% of unique firms in dataset for bootstrap
#        nfirms = len(set(firmids))
#        sb = nfirms
#
#        # block bootstrap with replacement
#        bootfirms = np.random.choice(firmids, sb, replace=True)
#        idx_boot = dat.index.isin(bootfirms, level='firmid')
#        boot_sample = dat[idx_boot]
#
#        # add index columns
#        # here Pandas complain about a SettingWithCopyWarning
#        # used .loc as recommended but Pandas still complains
#        firms = boot_sample.index.get_level_values('firmid')
#        dates = boot_sample.index.get_level_values('date')
#        boot_sample.loc[:,'firmid'] = list(firms)
#        boot_sample.loc[:,'date'] = list(dates)
#        boot_sample.reset_index(drop=True, inplace=True)
#        

        dat = self.dta_mu.copy()
        # drop columns with lag (because they have missing values)
        dat = dat.drop(columns = [c for c in dat.columns if 'lag' in c])
        
        firmids = list(dat.firmid.unique()) # list of firm ids
        sb = len(set(firmids)) # extract an equal number of firms
#        sb = 5 * len(set(firmids)) # extract an equal number of firms
        bootfirms = np.random.choice(firmids, sb, replace=True) # extract the random sample (with replace)
        
        boot_sample = pd.DataFrame(bootfirms, columns = ['firmid']) # init the sample
        boot_sample['new_id'] = boot_sample.index # create this variabel for re-indexing firms
        
        boot_sample = boot_sample.merge(dat, how='left', on='firmid') # merge with the original dataset
        
        # take care of firmid, ensureing that each booted firm has unique id
        boot_sample = boot_sample.drop(columns = ['firmid'])
        boot_sample = boot_sample.rename(columns = {'new_id':'firmid'}) 

        return boot_sample 

    def boot_se(self):
        """ Bootstrap SE

            Input
            -----
            nb : int, number of bootstrap repetition  
            alpha_se : float, two-sided confidence level in [0, 1] 

            Output
            ------
            beta_se : numpy 1D array, bootstrap SE for beta 
                        (elasticities)
            CI : numpy 2d array, CI for betas.
                 1st column is lower endpoint CI
                 and 2nd column is upper endpoint CI            
        """  

        pd.options.mode.chained_assignment = None  # suppress warnings, for clean output
        
        if self.verbose:
            print('Running bootstrap...')
            
        t0 = time()
        
        lbetas = [None] * self.nb
        for i in range(self.nb):  # 
            # if i % 10 == 0:
#                print(i)       
            # draw sample from CLEANED dataset
            # and store subsample in self.dta
            # since all methods utilize self.dta
            subsample = self.boot_block_resample()   
            
            P_se = ProdFun(subsample,
                           vars_how = self.vars_how,
                           M=self.M,
                           purging=self.purging,
                           phi_true=self.phi_true,
                           FE=self.FE,
                           FE_demean = self.FE_demean,
                           translog = self.translog,
                           tl_inter = self.tl_inter,
                           GMM_cons = self.GMM_cons,
                           beta_init = self.betas,
                           #beta_init = self.beta_init,
                           delta_simplex = self.delta_simplex,
                           beta_true = self.beta_true,
                           init_control = self.init_control,
                           optim=self.optim,
                           AR_c=self.AR_c,
                           AR_sqlag=self.AR_sqlag,
                           AR_price = self.AR_price,
                           omega_price = self.omega_price,
                           NormMoments = self.NormMoments,
                           LogMoments = self.LogMoments,
                           RTS_const = self.RTS_const,
                           RTS=self.RTS,
                           subsample=self.subsample,
                           markups = False,
                           verbose=False,
                           se=False,
                           )

            
            
            # store fitted betas and reset dta to
#            lbetas[i] = P_se.betas
            if P_se.optim_used == 'NM':
#                if P_se.conv[1] < 10**(-6): # control that the function is close enough to 0
#                if P_se.conv[1] < 10 * self.conv[1]: # control that the function is close enough to 0
#                if P_se.conv[4] == 0: # control that the optimization raised no warning
                if P_se.conv[1] < 10 * self.conv[1] and P_se.conv[4] == 0: # control that the function is close enough to 0
                    lbetas[i] = P_se.betas
#                if self.LogMoments:
#                    if P_se.conv[1] < min([self.conv[1]+10,10]): # control that moments are close to 0
##                    if P_se.conv[1] < self.conv[1]*10: # control that moments are close to 0
#                        lbetas[i] = P_se.betas
#                else:
#                    if P_se.conv[4] == 0 and P_se.conv[1] < self.conv[1]*10: # control that the optimization raised no warning
#                        lbetas[i] = P_se.betas
#                else:
#                    print('Bootstrap repetition ' + str(i) + ' discarded')                    
            elif P_se.optim_used == 'FSOLVE':
#                if P_se.conv[-3]['fvec'].sum() < 10**(-6): # control that the function is close enough to 0
#                if P_se.conv[-3]['fvec'].sum() < 10 * self.conv[-3]['fvec'].sum(): # control that the function is close enough to 0
                if P_se.conv[-2] == 1: # control that a solution was found
#                if P_se.conv[-2] == 1 and P_se.conv[-3]['fvec'].sum() < 10**(-6): # control that a solution was found
#                if P_se.conv[-2] == 1 and P_se.conv[-3]['fvec'].sum() < 2 * self.conv[-3]['fvec'].sum(): # control that a solution was found
                    lbetas[i] = P_se.betas
#                else:
#                    print('Bootstrap repetition ' + str(i) + ' discarded')                    
            elif P_se.optim_used == 'BASINHOPPING':
#                if P_se.ret.fun < 10**(-6): # control that the function is close enough to 0
#                if P_se.ret.fun < 10 * self.ret.fun: # control that the function is close enough to 0
                if P_se.ret.lowest_optimization_result.success == True: # control that a solution was found
#                if P_se.ret.lowest_optimization_result.success == True and P_se.ret.fun < 10**(-6): # control that a solution was found
#                if P_se.ret.lowest_optimization_result.success == True and P_se.ret.fun < 10 * self.ret.fun: # control that a solution was found
                    lbetas[i] = P_se.betas
#                else:
#                    print('Bootstrap repetition ' + str(i) + ' discarded')                    
            else:
                lbetas[i] = P_se.betas
            
        lbetas = [l for l in lbetas if l is not None] # drop discarded repetitions
        if self.verbose:
            print('Using '+str(len(lbetas))+'/'+str(self.nb)+' bootstrap repetitions')
        
        pd.options.mode.chained_assignment = 'warn'  # return to default='warn'
        
        # compute SEs and CI for betas
        BBetas = np.array(lbetas)   # bootstrap betas
        betas_se = np.std(BBetas, 0)  # SEs
        z = stats.norm.ppf(self.alpha_se/2)
        CI = np.array([self.betas + z*betas_se,   # low CI
                       self.betas - z*betas_se])  # high CI

        betas_cov = np.cov(BBetas, rowvar=False)
        
        ##test joint significance of TL coefficients
        if self.translog:
            
            #get numbers for the test
            n_CD = len([key for key in self.vars_how if self.vars_how[key][1]!='no input']) # number of CD coeff
            if self.GMM_cons: n_CD = n_CD + 1 # add the constant
            n_TL = len(self.variables) - n_CD # number of TL coeff (hypothesis to test)
            n = int(self.Y.size) # number of observations
            
            betas_2d = np.array([self.betas[n_CD:]]).T
            
            ## Wald Statistic
            # with estimated covariance matrix
#            W = betas_2d.T @ np.linalg.inv(betas_cov[n_CD:,n_CD:] / n) @ betas_2d / n_TL
            W = betas_2d.T @ np.linalg.inv(betas_cov[n_CD:,n_CD:]) @ betas_2d / n_TL
            p_value = 1 - stats.f.cdf(W, n_TL, n-n_CD-n_TL)
            TL_test = {'p_value': float(p_value), 'WaldStatistic': float(W)} # dictionary to store the results
            
        
        t1 = time()
        if self.verbose:
            print('Elapsed time: ' + str(np.round(t1-t0,2)))
            print('Resulting Betas SE:')
            print(betas_se)
            if self.translog:
                print('Joint significance of translog coefficients:')
                print(TL_test)
            
            
        # store as attribute since 
        # method should be run when initialized...
        self.ses = betas_se
        self.CI = CI.T
        self.BBetas = BBetas
        if self.translog:
            self.betas_cov = betas_cov
            self.TL_test = TL_test
        
        
        
        return betas_se, CI.T
